import { _mergeDicts } from "./base.js";
import type { MessageOutputVersion } from "./message.js";

/**
 * Provider- and model-specific metadata attached to a message response.
 *
 * The index signature permits arbitrary additional keys beyond the
 * well-known optional fields below.
 */
export type ResponseMetadata = {
  /** Identifier of the model provider that produced the response. */
  model_provider?: string;
  /** Name of the model that produced the response. */
  model_name?: string;
  /** Version of the message output format (see {@link MessageOutputVersion}). */
  output_version?: MessageOutputVersion;
  /** Providers may attach arbitrary extra metadata under any key. */
  [key: string]: unknown;
};

/**
 * Merge two response-metadata objects into a new one.
 *
 * Either side may be omitted; a missing side is treated as an empty
 * object. Combination semantics are delegated to `_mergeDicts`.
 */
export function mergeResponseMetadata(
  a?: ResponseMetadata,
  b?: ResponseMetadata
): ResponseMetadata {
  const merged: ResponseMetadata = _mergeDicts(a ?? {}, b ?? {});
  return merged;
}

/**
 * Token counts broken down by content modality.
 *
 * All fields are optional; a provider reports only the modalities it
 * tracks. An absent field means "not reported", not zero.
 */
export type ModalitiesTokenDetails = {
  /**
   * Text tokens.
   * Does not need to be reported, but some models will do so.
   */
  text?: number;

  /**
   * Image (non-video) tokens.
   */
  image?: number;

  /**
   * Audio tokens.
   */
  audio?: number;

  /**
   * Video tokens.
   */
  video?: number;

  /**
   * Document tokens.
   * e.g. PDF
   */
  document?: number;
};

/**
 * Combine two per-modality token breakdowns by summing each counter.
 *
 * A counter appears in the result only when at least one side defines
 * it; a counter absent from both sides stays absent (rather than
 * becoming 0). Either side may be omitted entirely.
 */
function mergeModalitiesTokenDetails(
  a?: ModalitiesTokenDetails,
  b?: ModalitiesTokenDetails
): ModalitiesTokenDetails {
  const merged: ModalitiesTokenDetails = {};
  // Key order matches the original assignment order so the merged
  // object's property insertion order is unchanged.
  const keys = ["audio", "image", "video", "document", "text"] as const;
  for (const key of keys) {
    const left = a?.[key];
    const right = b?.[key];
    if (left !== undefined || right !== undefined) {
      merged[key] = (left ?? 0) + (right ?? 0);
    }
  }
  return merged;
}

/**
 * Breakdown of input token counts.
 *
 * Extends {@link ModalitiesTokenDetails} with cache-related counters.
 * Does *not* need to sum to the full input token count. Does *not* need
 * to have all keys.
 */
export type InputTokenDetails = ModalitiesTokenDetails & {
  /**
   * Input tokens that were cached and there was a cache hit.
   *
   * Since there was a cache hit, the tokens were read from the cache.
   * More precisely, the model state given these tokens was read from the cache.
   */
  cache_read?: number;

  /**
   * Input tokens that were cached and there was a cache miss.
   *
   * Since there was a cache miss, the cache was created from these tokens.
   */
  cache_creation?: number;
};

/**
 * Merge two input-token breakdowns: per-modality counters are summed by
 * `mergeModalitiesTokenDetails`, then the cache counters are summed on
 * top. A counter is present in the result only if at least one side
 * defines it.
 */
function mergeInputTokenDetails(
  a?: InputTokenDetails,
  b?: InputTokenDetails
): InputTokenDetails {
  // The modalities merger returns a fresh object, so it is safe to
  // extend in place.
  const merged: InputTokenDetails = mergeModalitiesTokenDetails(a, b);
  for (const key of ["cache_read", "cache_creation"] as const) {
    const left = a?.[key];
    const right = b?.[key];
    if (left !== undefined || right !== undefined) {
      merged[key] = (left ?? 0) + (right ?? 0);
    }
  }
  return merged;
}

/**
 * Breakdown of output token counts.
 *
 * Extends {@link ModalitiesTokenDetails} with a reasoning counter.
 * Does *not* need to sum to the full output token count. Does *not* need
 * to have all keys.
 */
export type OutputTokenDetails = ModalitiesTokenDetails & {
  /**
   * Reasoning output tokens.
   *
   * Tokens generated by the model in a chain of thought process (i.e. by
   * OpenAI's o1 models) that are not returned as part of model output.
   */
  reasoning?: number;
};

/**
 * Merge two output-token breakdowns: per-modality counters are summed by
 * `mergeModalitiesTokenDetails`, then the reasoning counter is summed on
 * top. `reasoning` is present in the result only if at least one side
 * defines it.
 */
function mergeOutputTokenDetails(
  a?: OutputTokenDetails,
  b?: OutputTokenDetails
): OutputTokenDetails {
  // The modalities merger returns a fresh object, so it is safe to
  // extend in place.
  const merged: OutputTokenDetails = mergeModalitiesTokenDetails(a, b);
  const left = a?.reasoning;
  const right = b?.reasoning;
  if (left !== undefined || right !== undefined) {
    merged.reasoning = (left ?? 0) + (right ?? 0);
  }
  return merged;
}

/**
 * Usage metadata for a message, such as token counts.
 *
 * The three top-level counts are required; the per-type breakdowns are
 * optional.
 */
export type UsageMetadata = {
  /**
   * Count of input (or prompt) tokens. Sum of all input token types.
   */
  input_tokens: number;
  /**
   * Count of output (or completion) tokens. Sum of all output token types.
   */
  output_tokens: number;
  /**
   * Total token count. Sum of input_tokens + output_tokens.
   */
  total_tokens: number;

  /**
   * Breakdown of input token counts.
   *
   * Does *not* need to sum to full input token count. Does *not* need to have all keys.
   */
  input_token_details?: InputTokenDetails;

  /**
   * Breakdown of output token counts.
   *
   * Does *not* need to sum to full output token count. Does *not* need to have all keys.
   */
  output_token_details?: OutputTokenDetails;
};

/**
 * Merge two usage-metadata objects by summing their token counts.
 *
 * Either side may be omitted (treated as all-zero). The token-detail
 * breakdowns are merged with their dedicated helpers and are always
 * present on the result, even if empty.
 */
export function mergeUsageMetadata(
  a?: UsageMetadata,
  b?: UsageMetadata
): UsageMetadata {
  // Missing counts on either side contribute 0.
  const sum = (x?: number, y?: number): number => (x ?? 0) + (y ?? 0);
  return {
    input_tokens: sum(a?.input_tokens, b?.input_tokens),
    output_tokens: sum(a?.output_tokens, b?.output_tokens),
    total_tokens: sum(a?.total_tokens, b?.total_tokens),
    input_token_details: mergeInputTokenDetails(
      a?.input_token_details,
      b?.input_token_details
    ),
    output_token_details: mergeOutputTokenDetails(
      a?.output_token_details,
      b?.output_token_details
    ),
  };
}
